{% extends 'base.attached.class' %}

{% block content %}
{% if is_test or to_json %}
import java.util.Arrays;

{% endif %}
/**
 * Bernoulli naive Bayes estimator evaluated from precomputed model data.
 *
 * <p>The model data is treated as living in log space: {@code probs} entries
 * are passed through {@link Math#exp(double)} and {@code priors} entries are
 * added directly to the joint log-likelihood.
 */
class {{ class_name }} {

    private final double[] priors;
    private final double[][] probs;

    /**
     * Creates an estimator from its model data. The arrays are stored by
     * reference, not copied.
     *
     * @param priors one value per class, added to that class's joint
     *               log-likelihood
     * @param probs  one row per class and one column per feature; each entry
     *               is exponentiated during evaluation
     */
    public {{ class_name }}(double[] priors, double[][] probs) {
        this.priors = priors;
        this.probs = probs;
    }

    /**
     * Returns the index of the largest element, preferring the first one on
     * ties. Returns 0 for an empty array.
     *
     * @param nums values to scan
     * @return index of the first maximum
     */
    private int findMax(double[] nums) {
        int idx = 0;
        for (int i = 1; i < nums.length; i++) {
            if (nums[i] > nums[idx]) {
                idx = i;
            }
        }
        return idx;
    }

    /**
     * Computes {@code log(sum(exp(nums[i])))}.
     *
     * @param nums values in log space; must not be empty
     * @return the logarithm of the sum of the exponentials
     */
    private double logSumExp(double[] nums) {
        // Shift by the maximum so the largest exponent is exp(0) = 1.
        double max = nums[findMax(nums)];
        double sum = 0.;
        for (double num : nums) {
            sum += Math.exp(num - max);
        }
        // log(sum(exp(x))) = max + log(sum(exp(x - max))); the shift is added back.
        return max + Math.log(sum);
    }

    /**
     * Computes the joint log-likelihood of every class for one sample.
     *
     * <p>Each feature is binarized (1 if strictly greater than 0, else 0)
     * without modifying the caller's array. Only the first
     * {@code probs[0].length} entries of {@code features} are read.
     *
     * @param features raw feature values of the sample
     * @return one joint log-likelihood per class
     */
    private double[] compute(double[] features) {
        int nClasses = this.probs.length;
        int nFeatures = this.probs[0].length;
        double[] jll = new double[nClasses];

        for (int i = 0; i < nClasses; i++) {
            double active = 0.;    // sum over j of x_j * (probs - delta)
            double inactive = 0.;  // sum over j of delta
            for (int j = 0; j < nFeatures; j++) {
                double logProb = this.probs[i][j];
                // log(1 - p): log-probability of the feature being absent.
                double delta = Math.log(1 - Math.exp(logProb));
                // Binarize into a local so the caller's array stays untouched.
                double x = features[j] > 0 ? 1 : 0;
                active += x * (logProb - delta);
                inactive += delta;
            }
            jll[i] = active + this.priors[i] + inactive;
        }
        return jll;
    }

    /**
     * Predicts the class of a sample.
     *
     * @param features raw feature values of the sample
     * @return index of the class with the highest joint log-likelihood
     */
    public int predict(double[] features) {
        return findMax(compute(features));
    }

    /**
     * Computes the class probabilities of a sample.
     *
     * @param features raw feature values of the sample
     * @return one probability per class, normalized with log-sum-exp
     */
    public double[] predictProba(double[] features) {
        double[] jll = compute(features);
        double logNorm = logSumExp(jll);
        for (int i = 0; i < jll.length; i++) {
            jll[i] = Math.exp(jll[i] - logNorm);
        }
        return jll;
    }

    /**
     * Command-line entry point: parses one feature value per argument,
     * builds the estimator from the embedded model data and prints the result.
     *
     * @param args feature values; the count must equal the model's feature count
     * @throws IllegalArgumentException if the number of arguments is wrong
     * @throws NumberFormatException    if an argument is not a parsable double
     */
    public static void main(String[] args) {

        int nFeatures = {{ n_features }};
        if (args.length != nFeatures) {
            throw new IllegalArgumentException("You have to pass " +  String.valueOf(nFeatures) + " features.");
        }

        // Features:
        double[] features = new double[args.length];
        for (int i = 0, l = args.length; i < l; i++) {
            features[i] = Double.parseDouble(args[i]);
        }

        // Model data:
        {{ priors }}
        {{ probs }}

        // Estimator:
        {{ class_name }} clf = new {{ class_name }}(priors, probs);

        {% if is_test or to_json %}
        // Get JSON:
        int prediction = clf.predict(features);
        double[] probabilities = clf.predictProba(features);
        // Arrays.toString already renders "[a, b, ...]", which is a valid JSON array.
        System.out.println("{\"predict\": " + String.valueOf(prediction) + ", \"predict_proba\": " + Arrays.toString(probabilities) + "}");
        {% else %}
        // Get class prediction:
        int prediction = clf.predict(features);
        System.out.println("Predicted class: #" + String.valueOf(prediction));

        // Get class probabilities:
        double[] probabilities = clf.predictProba(features);
        for (int i = 0; i < probabilities.length; i++) {
            System.out.println("Probability of class #" + i + " : " + String.valueOf(probabilities[i]));
        }
        {% endif %}

    }
}
{% endblock %}